# load required packages
pkg <- c(
  "furrr","tidyverse",  "readxl", "xlsx", "stargazer", "modelsummary", "kableExtra", "estimatr","wesanderson", "kableExtra", "lubridate", "broom", "hrbrthemes", "patchwork", "sandwich", "lmtest", "Amelia", "mediation", "viridis", "wfe", "ggridges", "ZeligChoice", "Zelig", "broom") 

lapply(pkg, require, character.only = TRUE)

# Report the R version and loaded packages (auto-printed when run at top level).
sessionInfo()

# modelsummary: use "plain" number formatting for LaTeX output.
options(modelsummary_format_numeric_latex = "plain")

# Base text size for my_theme(); its other text sizes are offsets from this value.
axis.title.size <- 12

my_theme <- function(legend.position = "bottom") {
  # Shared ggplot2 theme used by the plotting helpers in this file:
  # theme_bw() base, no grid lines, no strip background, bold centred titles,
  # and text sizes expressed as offsets from the global `axis.title.size`.
  #
  # Arguments:
  #   legend.position  forwarded to theme(legend.position = ); default "bottom".
  # Returns: a ggplot2 theme object, meant to be added to a plot with `+`.
  theme_bw() +
    theme(
      legend.position = legend.position,
      legend.text = element_text(size = axis.title.size + 2),
      legend.title = element_blank(),
      plot.subtitle = element_text(size = axis.title.size,
                                   face = "bold",
                                   color = "#9a9a9a",
                                   hjust = 0.5),
      plot.title = element_text(size = axis.title.size + 4,
                                face = "bold",
                                hjust = 0.5,
                                margin = margin(0.5, 0.5, 10, 0.5)),
      plot.margin = margin(0.5, 0.2, 0.2, 0.2, "cm"),
      axis.title.y = element_text(size = axis.title.size + 2,
                                  margin = margin(0, 10, 0, 0)),
      axis.title.x = element_text(size = axis.title.size + 2,
                                  margin = margin(0, 0, 10, 0)),
      axis.text.y = element_text(size = axis.title.size + 2,
                                 margin = margin(10, 5, 10, 0)),
      axis.text.x = element_text(size = axis.title.size,
                                 margin = margin(5, 10, 10, 0)),
      strip.text = element_text(size = axis.title.size + 4,
                                face = "bold",
                                margin = margin(20, 10, 10, 10)),
      strip.background = element_blank(),
      panel.grid = element_blank(),
      # NOTE(review): ggplot2 >= 3.4 deprecates `size` here in favour of
      # `linewidth`; `size` is kept so older ggplot2 versions still work.
      panel.border = element_rect(colour = "grey20", fill = NA, size = 1.25)
    )
}


# Run `fn_simulation` exactly `nsim` times, split across `nworkers` futures
# (furrr), and row-bind the results.
#
# Arguments:
#   nworkers       number of batches handed to future_map_dfr() (>= 1)
#   nsim           total number of simulation draws (>= 0)
#   fn_simulation  function called as fn_simulation(data, treatment, control,
#                  outcome, covariate_ind, covariate_firm, pretreat_covs,
#                  covs_include); must return something map_dfr() can bind
#   remaining args forwarded unchanged to fn_simulation
# Returns: the row-bound results of all nsim calls.
# Reproducibility: furrr_options(seed = TRUE) gives each batch its own
# parallel-safe RNG stream.
run_simulation <- function(nworkers, nsim, fn_simulation, data, treatment,
                           control, outcome, covariate_ind, covariate_firm,
                           pretreat_covs, covs_include) {
  stopifnot(nworkers >= 1, nsim >= 0)

  # Split nsim as evenly as possible. The first (nsim %% nworkers) batches get
  # one extra draw, so the batch sizes always sum to nsim. Batches can be empty.
  batch_sizes <- rep(nsim %/% nworkers, nworkers) +
    (seq_len(nworkers) <= nsim %% nworkers)

  # Run one batch of `batch_size` draws sequentially inside a worker.
  # seq_len() makes a zero-size batch run nothing (1:0 would run twice).
  run_batch <- function(batch_size) {
    map_dfr(seq_len(batch_size), function(i) {
      fn_simulation(data, treatment, control, outcome, covariate_ind,
                    covariate_firm, pretreat_covs, covs_include)
    })
  }

  future_map_dfr(batch_sizes, run_batch,
                 .options = furrr_options(seed = TRUE))
}


# Fit an ordered logistic regression (MASS::polr, logit link).
#
# The model regresses `outcome` on `treatment`, plus `covs` when covs_include
# is TRUE. It uses only rows where the treatment or control indicator column
# sums to a positive value.
#
# Arguments:
#   treatment, control  column names (strings) of the two group indicators
#   outcome             column name of the response (polr needs a factor)
#   covs                character vector of covariate column names
#   data                data frame to subset
#   covs_include        if TRUE, add `covs` to the right-hand side
# Returns: list(df = estimation sample, formula = model formula, fit = polr fit)
# Side effects: assigns u_df, u_formula and u_fit in the enclosing (global)
#   environment via `<<-`. This is kept because other code may read them —
#   NOTE(review): confirm and remove if unused.
run_polr <- function(treatment, control, outcome, covs, data, covs_include = TRUE) {
  u_df <<- data %>% filter(!!sym(treatment) + !!sym(control) > 0)

  rhs <- if (covs_include) c(treatment, covs) else treatment
  u_formula <<- paste0(outcome, "~", paste(rhs, collapse = "+")) %>% as.formula()

  # Hess = TRUE stores the Hessian, so summary()/vcov() do not have to refit.
  u_fit <<- MASS::polr(formula = u_formula, data = u_df, Hess = TRUE,
                       method = "logistic")

  list(df = u_df, formula = u_formula, fit = u_fit)
}



# Fit a linear model with estimatr::lm_robust. The model includes
# industry_group fixed effects and clusters standard errors by industry_group.
#
# Only rows where treatment + control > 0 are used. The right-hand side is
# `treatment`, plus `covs` when covs_include is TRUE.
# Returns: list(df = estimation sample, formula = model formula, fit = lm_robust fit)
# Side effects: u_df and u_formula are assigned with `<<-`; the fit itself is
#   not exported.
run_lm <- function(treatment, control, outcome, covs, data, covs_include = TRUE) {
  in_sample <- data %>% filter(!!sym(treatment) + !!sym(control) > 0)

  rhs_terms <- if (covs_include) c(treatment, covs) else treatment
  model_formula <- as.formula(paste0(outcome, "~", paste(rhs_terms, collapse = "+")))

  # publish sample and formula outside the function, as callers may expect
  u_df <<- in_sample
  u_formula <<- model_formula

  u_fit <- estimatr::lm_robust(
    formula = u_formula,
    data = u_df,
    fixed_effects = ~ industry_group,
    clusters = u_df$industry_group
  )

  list(df = u_df, formula = u_formula, fit = u_fit)
}



# Fit a binary logistic regression (glm, binomial/logit). It regresses
# `outcome` on `treatment`, plus `covs` when covs_include is TRUE (default,
# matching run_polr/run_lm). Only rows where treatment + control > 0 are used.
#
# Returns: list(df = estimation sample, formula = model formula,
#               fit = glm fit, out = broom::tidy() coefficient table)
# Side effects: assigns u_df, u_formula and u_fit via `<<-`.
run_logit <- function(treatment, control, outcome, covs, data, covs_include = TRUE) {
  u_df <<- data %>% filter(!!sym(treatment) + !!sym(control) > 0)

  rhs <- if (covs_include) c(treatment, covs) else treatment
  u_formula <<- paste0(outcome, "~", paste(rhs, collapse = "+")) %>% as.formula()

  u_fit <<- glm(formula = u_formula, data = u_df, family = binomial("logit"))

  u_out <- u_fit %>% broom::tidy()

  list(df = u_df, formula = u_formula, fit = u_fit, out = u_out)
}


# Fit the three treatment-vs-control logit specifications on the global
# `u_sample`. Each uses the covariates in the globals covariate_ind,
# covariate_firm and pretreat_covs. The function returns the treatment
# coefficients with normal-approximation (Wald) intervals.
#
# Arguments: outcome1/2/3 — outcome column for each of the three models.
# Returns: list(est1 = 95% intervals, est2 = 90% intervals,
#               model1, model2, model3 = run_logit() outputs, whose `out`
#               tables carry a `model` label column)
run_logit_all <- function(outcome1, outcome2, outcome3) {
  covs <- c(covariate_ind, covariate_firm, pretreat_covs)

  # Fit one specification and tag its coefficient table with `label`.
  fit_labelled <- function(treatment, control, outcome, label) {
    info <- run_logit(
      data = u_sample,
      treatment = treatment,
      control = control,
      covs = covs,
      outcome = outcome
    )
    info$out$model <- label
    info
  }

  # Wald interval estimate +/- z_crit * std.error, tagged with its nominal width.
  add_wald_ci <- function(est, z_crit, ci_width) {
    est %>%
      mutate(
        conf.low = estimate - z_crit * std.error,
        conf.high = estimate + z_crit * std.error,
        .width = ci_width
      )
  }

  # NOTE(review): model 1 uses "u_treatment_control" as its control while the
  # others use "u_treatment_firm_*" columns — confirm this is intended.
  model1 <- fit_labelled("u_treatment_firm_us", "u_treatment_control",
                         outcome1, "US withdrawal")
  model2 <- fit_labelled("u_treatment_firm_china", "u_treatment_firm_multiple",
                         outcome2, "China stays")
  model3 <- fit_labelled("u_treatment_firm_multiple", "u_treatment_firm_us",
                         outcome3, "Multiple withdrawal")

  treatment_terms <- bind_rows(model1$out, model2$out, model3$out) %>%
    filter(str_detect(term, "u_treatment"))

  list(
    est1 = add_wald_ci(treatment_terms, 1.96, 0.95),
    est2 = add_wald_ci(treatment_terms, 1.645, 0.9),
    model1 = model1,
    model2 = model2,
    model3 = model3
  )
}


# Estimate the treatment effect on the predicted probability of each outcome
# category of an ordered logit fit.
#
# For every category k, predicted P(Y = k) is regressed on the numeric
# treatment indicator with lm(); the "tr" coefficient is kept.
#
# Arguments:
#   model      list with element `fit`, a polr fit (as returned by run_polr)
#   treatment  name of the treatment column in the fit's model frame
# Returns: tibble with one broom row per category (term == "tr"), plus
#   `n` (rows used) and `names` ("level1" ... "levelK").
compute_qoi <- function(model, treatment) {
  # n x K matrix of fitted category probabilities for the estimation sample
  predict_prob <- predict(model$fit, type = "probs")

  # One name per category, whatever K is. The previous hard-coded names only
  # fitted K = 3 or K = 6.
  colnames(predict_prob) <- paste0("level", seq_len(ncol(predict_prob)))

  # Loop-invariant: attach probabilities to the model frame once.
  df <- bind_cols(predict_prob, model$fit$model) %>%
    mutate(tr = as.numeric(!!sym(treatment)))
  n_obs <- nrow(df)

  out <- map_dfr(colnames(predict_prob), function(level) {
    mod <- lm(as.formula(paste(level, "~tr")), data = df)
    broom::tidy(summary(mod)) %>%
      dplyr::filter(term == "tr") %>%
      mutate(n = n_obs)
  })

  out <- as_tibble(out)
  out$names <- colnames(predict_prob)
  out
}


# Coefficient plot for ordered-logit estimates: a thin 95% interval
# (conf.low.95/conf.high.95) with a thicker 90% interval (conf.low.90/
# conf.high.90) drawn on top, coloured and dodged by `cov`.
#
# Arguments:
#   data      data frame with the columns above plus `cov`
#   term      name of the column mapped to the y axis
#   estimate  name of the column mapped to the x axis (default "estimate")
#   width, fatten_point, interval_size_range, position  ggdist/dodge settings
# Returns: a ggplot object.
plot_orderlogit_estimates <- function(data, term, estimate = "estimate",
                                      width = 0.5, fatten_point = 2,
                                      interval_size_range = c(1, 1.5),
                                      position = 0.5) {
  y_var <- sym(term)
  x_var <- sym(estimate)
  palette <- c("#C93312", "#6f6f6f", "#006ad5", "#E69F00", "#56B4E9",
               "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7")

  ci95_layer <- ggdist::geom_pointinterval(
    data = data,
    width = width,
    position = position_dodge(position),
    interval_size_range = interval_size_range,
    fatten_point = fatten_point
  )

  ci90_layer <- ggdist::geom_pointinterval(
    data = data,
    mapping = aes(y = !!y_var, x = !!x_var,
                  xmin = conf.low.90, xmax = conf.high.90, color = cov),
    width = width,
    position = position_dodge(position),
    interval_size_range = c(1.5, 2),
    fatten_point = fatten_point
  )

  ggplot(data, aes(y = !!y_var, x = !!x_var,
                   xmin = conf.low.95, xmax = conf.high.95, color = cov)) +
    ci95_layer +
    ci90_layer +
    my_theme() +
    xlab("Estimates") +
    scale_color_manual(values = palette) +
    geom_vline(xintercept = 0, color = "#6f6f6f", linetype = "dashed") +
    ylab("")
}


# Coefficient plot for logit estimates in long format (e.g. run_logit_all()'s
# est1/est2 stacked). Rows with .width == 0.95 are drawn as thin intervals and
# rows with .width == 0.9 as thicker intervals on top. Both use
# conf.low/conf.high, coloured and dodged by `cov`.
#
# Arguments:
#   data      data frame with conf.low, conf.high, .width, cov and the columns
#             named by `term` and `estimate`
#   term      column mapped to the y axis
#   estimate  column mapped to the x axis (default "estimate")
#   width, fatten_point, interval_size_range, position  ggdist/dodge settings
# Returns: a ggplot object.
plot_logit_estimates <- function(data, term, estimate = "estimate",
                                 width = 0.5, fatten_point = 2,
                                 interval_size_range = c(1, 1.5),
                                 position = 0.5) {
  # Split once so each layer draws only its own interval width. Previously the
  # first layer got the unfiltered data and redrew the 90% rows as thin lines.
  data_95 <- data %>% filter(dplyr::near(.width, 0.95))
  data_90 <- data %>% filter(dplyr::near(.width, 0.9))

  ggplot(data_95, aes(y = !!sym(term), x = !!sym(estimate),
                      xmin = conf.low, xmax = conf.high, color = cov)) +
    ggdist::geom_pointinterval(
      data = data_95,
      width = width,
      position = position_dodge(position),
      interval_size_range = interval_size_range,
      fatten_point = fatten_point
    ) +
    ggdist::geom_pointinterval(
      data = data_90,
      mapping = aes(y = !!sym(term), x = !!sym(estimate),
                    xmin = conf.low, xmax = conf.high, color = cov),
      width = width,
      position = position_dodge(position),
      interval_size_range = c(1.5, 2),
      fatten_point = fatten_point
    ) +
    my_theme() +
    xlab("Estimates") +
    scale_color_manual(values = c("#C93312", "#6f6f6f", "#006ad5", "#E69F00",
                                  "#56B4E9", "#009E73", "#F0E442", "#0072B2",
                                  "#D55E00", "#CC79A7")) +
    geom_vline(xintercept = 0, color = "#6f6f6f", linetype = "dashed") +
    ylab("")
}



# Coefficient plot for OLS estimates: a thin 95% interval
# (conf.low.95/conf.high.95) with a thicker 90% interval (conf.low.90/
# conf.high.90) on top. Points sit at `estimate` and are coloured and dodged
# by `group`.
#
# Arguments:
#   data  data frame with estimate, group and the interval columns above
#   term  name of the column mapped to the y axis
#   width, fatten_point, interval_size_range, position  ggdist/dodge settings
# Returns: a ggplot object.
plot_ols_estimates <- function(data, term,
                               width = 0.5, fatten_point = 2,
                               interval_size_range = c(0.5, 1.5),
                               position = 0.5) {
  term_sym <- sym(term)
  ols_palette <- c("#6f6f6f", "#C93312", "#006ad5", "#E69F00", "#56B4E9",
                   "#009E73", "#F0E442", "#0072B2", "#D55E00", "#CC79A7")

  base_plot <- ggplot(data, aes(y = !!term_sym, x = estimate,
                                xmin = conf.low.95, xmax = conf.high.95,
                                color = group))

  base_plot +
    ggdist::geom_pointinterval(
      width = width,
      position = position_dodge(position),
      interval_size_range = interval_size_range,
      fatten_point = fatten_point
    ) +
    # colour is inherited from the base mapping (group)
    ggdist::geom_pointinterval(
      mapping = aes(y = !!term_sym, x = estimate,
                    xmin = conf.low.90, xmax = conf.high.90),
      width = width,
      position = position_dodge(position),
      interval_size_range = c(1.5, 2),
      fatten_point = fatten_point
    ) +
    my_theme() +
    xlab("Estimates") +
    scale_color_manual(values = ols_palette) +
    geom_vline(xintercept = 0, color = "#6f6f6f", linetype = "dashed") +
    ylab("")
}


# Coefficient plot with a single interval (conf.low/conf.high) per estimate,
# coloured and dodged by `cov`, with a dashed reference line at zero.
#
# Arguments:
#   data  data frame with estimate, conf.low, conf.high, cov
#   term  name of the column mapped to the y axis
#   width, fatten_point, interval_size_range, position  ggdist/dodge settings
# Returns: a ggplot object.
plot_estimate <- function(data, term,
                          width = 0.5, fatten_point = 1.5,
                          interval_size_range = c(1, 2),
                          position = 0.5) {
  mapping <- aes(y = !!sym(term), x = estimate,
                 xmin = conf.low, xmax = conf.high, color = cov)

  interval_layer <- ggdist::geom_pointinterval(
    width = width,
    position = position_dodge(position),
    interval_size_range = interval_size_range,
    fatten_point = fatten_point
  )

  ggplot(data, mapping) +
    interval_layer +
    my_theme() +
    xlab("Estimates") +
    geom_vline(xintercept = 0, color = "#6f6f6f", linetype = "dashed") +
    ylab("")
}



# Add 95% and 90% point-interval layers and p-value annotations
# (ggprism::add_pvalue) to an existing ggplot `p`, then flip the coordinates.
#
# Terms on the x axis are ordered by |estimate|. The 95% interval uses
# conf.low.95/conf.high.95 (thin); the 90% interval uses conf.low.90/
# conf.high.90 (thick).
#
# Arguments:
#   p           ggplot object whose data holds term, estimate and CI columns
#   two.means   data frame of p-values passed to ggprism::add_pvalue()
#   label       glue-style label template for the p-values
#   x_labs      x-axis title
#   y.position  position of the p-value labels
# Returns: the augmented ggplot object.
plot_estimate_labels <- function(
    p, two.means, label = "p = {p}", x_labs = "Estimates", y.position = 0.2) {
  ci_95_layer <- ggdist::geom_pointinterval(
    mapping = aes(x = reorder(term, abs(estimate)), y = estimate,
                  ymin = conf.low.95, ymax = conf.high.95),
    width = 0.5,
    position = position_dodge(0.5),
    interval_size_range = c(0.5, 1.5),
    fatten_point = 2
  )

  ci_90_layer <- ggdist::geom_pointinterval(
    mapping = aes(x = reorder(term, abs(estimate)), y = estimate,
                  ymin = conf.low.90, ymax = conf.high.90),
    width = 0.5,
    position = position_dodge(0.5),
    interval_size_range = c(1.5, 2),
    fatten_point = 2
  )

  pvalue_layer <- ggprism::add_pvalue(
    two.means,
    label = label,
    colour = "#a11f1f",
    label.size = 5,
    inherit.aes = FALSE,
    show.legend = FALSE,
    coord.flip = TRUE,
    y.position = y.position
  )

  p +
    ci_95_layer +
    ci_90_layer +
    pvalue_layer +
    my_theme() +
    xlab(x_labs) +
    coord_flip()
}


# Nonparametric bootstrap of a single data frame: draw nrow(data) rows with
# replacement. block_bootstrap() applies this within each stratum.
#
# Returns: a data frame with the same columns and number of rows as `data`.
bootstrap_data <- function(data) {
  # sample.int(n, replace = TRUE) draws the same indices as sample(1:n, ...)
  # but returns integer(0) when n == 0. (1:0 is c(1, 0), which produced NA
  # rows.) drop = FALSE keeps one-column data frames as data frames.
  data[sample.int(nrow(data), replace = TRUE), , drop = FALSE]
}


# Stratified (block) bootstrap: resample rows with replacement within every
# `strata` x `treatment` cell, so each cell keeps its original size.
#
# Arguments:
#   data       data frame containing the `strata` and `treatment` columns
#   treatment  name of the treatment column (string)
#   strata     name of the stratum column (string), default "industry"
# Returns: an ungrouped data frame with the resampled rows. The stratum and
#   treatment columns come first, as produced by nest()/unnest().
block_bootstrap <- function(data, treatment, strata = "industry") {
  data %>%
    group_by(!!sym(strata), !!sym(treatment)) %>%
    nest() %>%
    mutate(data = map(data, ~ bootstrap_data(.))) %>%
    unnest(cols = c(data)) %>%
    # don't leak grouping into the model-fitting code downstream
    ungroup()
}



# One bootstrap replicate for the ordered-logit analysis. The steps are:
# block-resample `data`, refit run_polr() (with or without covariates), and
# return the per-category treatment effects from compute_qoi().
#
# Arguments:
#   data, treatment, control, outcome  as for run_polr()
#   covariate_ind, covariate_firm, pretreat_covs  covariate name vectors, used
#                                      only when covs_include is TRUE
#   covs_include                       include the covariates in the model?
# Returns: data frame with columns `estimate` and `names`. If the fit or the
#   QoI step errors, it returns the single sentinel row
#   (estimate = NA, names = "error"), which get_filtered_data() drops.
bootstrap_orderlogit <- function(
    data, treatment, control, outcome, covariate_ind, covariate_firm,
    pretreat_covs, covs_include) {
  df <- block_bootstrap(data, treatment)

  covs <- if (covs_include) {
    c(pretreat_covs, covariate_ind, covariate_firm)
  } else {
    c()
  }

  tryCatch(
    {
      model <- run_polr(
        treatment = treatment,
        control = control,
        outcome = outcome,
        covs = covs,
        data = df,
        covs_include = covs_include
      )
      compute_qoi(model, treatment) %>% dplyr::select(estimate, names)
    },
    # Best-effort: a failed replicate becomes a sentinel row, not a crash.
    error = function(e) {
      data.frame(estimate = NA, names = "error")
    }
  )
}


# Count observations after filtering.
#
# Arguments:
#   data         data frame
#   filter_expr  bare filter expression (embraced with {{ }})
#   count_var    bare column to tabulate with count()
# Returns: c(<count for each level of count_var>, <total filtered rows>).
count_obs <- function(data, filter_expr, count_var) {
  # Local variable instead of `temp <<- .`, which overwrote any global `temp`.
  filtered <- data %>% filter({{ filter_expr }})

  n_obs <- filtered %>%
    count({{ count_var }}) %>%
    pull(n)

  c(n_obs, nrow(filtered))
}

# Split one stacked modelsummary_list (e.g. list(tidy = ..., glance = ...),
# one row per model) into a named list of per-model modelsummary_list objects.
#
# Arguments:
#   data    list of data frames sharing the same row order (one row per model)
#   labels  model names; model i takes row i of every component
# Returns: list named by `labels`. Each element is a modelsummary_list holding
#   row i of every component of `data`.
format_data <- function(data, labels) {
  models <- lapply(seq_along(labels), function(i) {
    # drop = FALSE keeps one-column components as data frames
    model <- lapply(data, function(component) component[i, , drop = FALSE])
    class(model) <- "modelsummary_list"
    model
  })
  setNames(models, labels)
}


# Render `models` with modelsummary (estimates annotated with significance
# stars) and write the table to `output_path` with kableExtra::save_kable().
#
# Returns: whatever save_kable() returns.
# NOTE(review): modelsummary interprets `output = ".tex"` as a file name, so
# this may also create a file literally called ".tex" in the working directory.
# Confirm this is intended, or pass a format name instead.
create_latex_table <- function(models, title, output_path) {
  latex_table <- modelsummary(
    models,
    estimate = "{estimate}{stars}",
    title = title,
    output = ".tex"
  )
  save_kable(latex_table, output_path)
}

# Build and save a LaTeX table with one column per treatment label.
#
# Arguments:
#   data         one row per label, with columns label, estimate, std.error,
#                p.value, conf.low, conf.high
#   labels       model/column names, in the desired order
#   n_obs        observation counts, in the same order as `labels`
#   output_path  destination file
#   title        table title
get_latex_table <- function(data, labels, n_obs, output_path, title) {
  models <- data %>%
    mutate(label = fct_relevel(label, labels)) %>%
    # format_data() assigns labels by row position, so the rows must actually
    # follow `labels`; relevelling alone does not reorder them.
    arrange(label) %>%
    format_datasummary(n_obs = n_obs)

  create_latex_table(format_data(models, labels), title, output_path)
}


# Convert a table of treatment estimates into a single stacked
# modelsummary_list. The `tidy` part has one "Treatment" row per row of `df`.
# The `glance` part has one row per element of `n_obs`, marking controls as
# included.
#
# Arguments:
#   df     data frame with estimate, std.error, p.value, conf.low, conf.high
#   n_obs  observation counts, one per model
# Returns: modelsummary_list(tidy = <data.frame>, glance = <data.frame>).
format_datasummary <- function(df, n_obs) {
  ti <- data.frame(
    term = rep("Treatment", nrow(df)),
    estimate = df$estimate,
    std.error = df$std.error,
    p.value = df$p.value,
    conf.low = df$conf.low,
    conf.high = df$conf.high
  )

  # check.names = FALSE keeps the human-readable row labels. Otherwise
  # data.frame() rewrites "Firm-level controls" to "Firm.level.controls".
  gl <- data.frame(
    `Firm-level controls` = rep("Yes", length(n_obs)),
    `Individual-level controls` = rep("Yes", length(n_obs)),
    `Num.Obs.` = n_obs,
    check.names = FALSE
  )

  mod <- list(tidy = ti, glance = gl)
  class(mod) <- "modelsummary_list"
  mod
}


# Build and save a LaTeX table of ordered-logit treatment effects for one
# outcome category and interval width. Only the "With controls" rows are used.
#
# Arguments:
#   data         long results with columns cov, label, .width, names, estimate,
#                std.error, p.value, conf.low, conf.high
#   labels       label order used to sort the rows
#   n_obs        observation counts, one per model
#   width        interval width to keep (matched against .width)
#   output_path  destination file
#   title        table title
#   treat        model/column names passed to format_data()
#   level        outcome category to report (matched against `names`),
#                default "level3"
get_latex_table_orderLogit <- function(data, labels, n_obs, width, output_path,
                                       title, treat, level = "level3") {
  # format_datasummary() takes no `width` argument, so the original call failed
  # with "unused argument". The width/level-aware formatter is the one intended.
  models <- data %>%
    filter(cov == "With controls") %>%
    mutate(label = fct_relevel(label, labels)) %>%
    # format_data() assigns names by row position, so sort rows by label
    arrange(label) %>%
    format_datasummary_orderLogit(n_obs = n_obs, width = width, level = level)

  create_latex_table(format_data(models, treat), title, output_path)
}

# Drop bootstrap replicates flagged as failures (names == "error" or
# "warning") and tag the remaining rows with `label`.
#
# Returns: `data` without the flagged rows, with a `label` column set to `label`.
get_filtered_data <- function(data, label) {
  data %>%
    filter(!names %in% c("error", "warning")) %>%
    # .env$ guarantees the function argument is used even if `data` already
    # has a `label` column (data masking would otherwise pick the column).
    mutate(label = .env$label)
}

# Convert ordered-logit results into a stacked modelsummary_list. It keeps only
# rows whose interval width equals `width` and whose outcome category
# (`names`) equals `level`; each kept row becomes one "Treatment" model.
#
# Arguments:
#   output  long results with columns .width, names, estimate, std.error,
#           p.value, conf.low, conf.high
#   n_obs   observation counts, one per kept row
#   width   interval width to keep (default 0.95)
#   level   outcome category to keep (default "level3")
# Returns: modelsummary_list(tidy = <data.frame>, glance = <data.frame>).
format_datasummary_orderLogit <- function(output, n_obs, width = 0.95, level = "level3") {
  # which() drops NA comparisons, matching dplyr::filter() semantics
  keep <- which(output[[".width"]] == width & output[["names"]] == level)
  df <- output[keep, , drop = FALSE]

  ti <- data.frame(
    term = rep("Treatment", nrow(df)),
    estimate = df$estimate,
    std.error = df$std.error,
    p.value = df$p.value,
    conf.low = df$conf.low,
    conf.high = df$conf.high
  )

  # check.names = FALSE keeps "Firm-level controls" etc. readable
  gl <- data.frame(
    `Firm-level controls` = "Yes",
    `Individual-level controls` = "Yes",
    `Num.Obs.` = n_obs,
    check.names = FALSE
  )

  mod <- list(tidy = ti, glance = gl)
  class(mod) <- "modelsummary_list"
  mod
}



# Extract the treatment coefficients from a fitted model list (e.g. the output
# of run_lm()) and tag them with `label`.
#
# Arguments:
#   data   list with a `fit` element that broom::tidy() understands
#   label  value stored in the new `label` column
# Returns: tidy coefficient rows whose term contains "u_treatment".
get_data <- function(data, label) {
  data$fit %>%
    tidy() %>%
    filter(str_detect(term, "u_treatment")) %>%
    # .env$ avoids data masking if the tidy output ever has a `label` column
    mutate(label = .env$label)
}


# Covariate balance table: the standardized difference between the weighted
# treated mean and the control mean, for each column. Both means are divided
# by the control count n0, and differences are scaled by the control-group SD.
# (The function name suggests an ATC-style comparison.)
#
# Arguments:
#   x_treat    numeric matrix/data frame of treated-unit covariates
#   x_control  numeric matrix/data frame of control-unit covariates
#   weights    per-treated-unit weights; default n0/n1 for every unit, so the
#              weights sum to n0 and the weighted treated "mean" is the plain mean
# Returns: tibble with columns Variable and `Standardized Difference`.
covariate_balance_atc <- function(x_treat, x_control, weights = NULL) {
  n_treat <- nrow(x_treat)
  n_ctrl <- nrow(x_control)
  if (is.null(weights)) {
    weights <- rep(n_ctrl / n_treat, n_treat)
  }
  weighted_treat <- x_treat * weights

  ctrl_sd <- apply(x_control, 2, sd)
  treat_mean <- apply(weighted_treat, 2, function(col) sum(col) / n_ctrl)
  ctrl_mean <- apply(x_control, 2, function(col) sum(col) / n_ctrl)

  imbalance <- (treat_mean - ctrl_mean) / ctrl_sd
  tibble(Variable = names(imbalance),
         `Standardized Difference` = imbalance)
}
